#!/usr/bin/python

import sys
sys.path = ["."]+sys.path
import pdb
from pdb import pm

import re
import random as pyrandom
from pylab import *
import tables
import ocrolib; reload(ocrolib)
from collections import Counter
import ocrolib
from ocrolib import patrec
from scipy.spatial import distance
from ocrolib.patrec import Dataset,cshow,showgrid
from ocrolib.ligatures import lig

import argparse

# Command-line interface.  Options are registered in the order in which
# they should appear in the --help output.
parser = argparse.ArgumentParser()

# how much data to use
parser.add_argument(
    '-N','--maxtotal',
    type=int,
    default=10000000000,
    help="max # of samples used")
parser.add_argument(
    '-t','--testset',
    type=int,
    default=-1,
    help="use testset sequence t (-1=use all samples)")
parser.add_argument(
    '-T','--maxtrain',
    type=int,
    default=100000,
    help="max # of training samples per classifier")

# input and output files
parser.add_argument(
    '-d','--data',
    default="chardata.h5",
    help="data file")
parser.add_argument(
    '-s','--splitter',
    default="split.splitter",
    help="input model to be updated")
parser.add_argument(
    '-o','--output',
    default="trained.cmodel",
    help="output with per-leaf classifiers")

# classifier construction
parser.add_argument(
    '--exec',
    dest="execute",
    default=None,
    help="additional modules to import (e.g., to load additional classifiers)")
parser.add_argument(
    '-C','--classifier',
    default="patrec.LogisticCmodel()",
    help="factory for leaf classifier")

# runtime behavior
parser.add_argument(
    '-D','--debug',
    action="store_true")
parser.add_argument(
    '-Q','--parallel',
    type=int,
    default=0,
    help="number of CPUs to use for training")
parser.add_argument(
    '--exclude',
    default=r"[ _\000-\037]",
    help="regular expression matching characters to exclude")

args = parser.parse_args()
#args = parser.parse_args(["-s","split100k.smodel","-N","100000"])

# The -C value is a Python expression; wrapping it in a lambda defers its
# evaluation, so a fresh classifier is built on every cfactory() call.
# NOTE: this evaluates arbitrary code taken from the command line, so
# only run this script with trusted arguments.
cfactory = eval("lambda:"+args.classifier)

if args.execute is not None:
    print args.execute
    exec args.execute

print "loading splitter"
import cPickle
with open(args.splitter) as stream:
    splitter = cPickle.load(stream)
print "got",splitter
print "#splits",splitter.nclusters()
print "excluding",args.exclude

def testset(i):
    """Tell whether sample `i` belongs to the held-out test set.

    With a non-negative `--testset` value, the decision is delegated to
    `ocrolib.testset` for that sequence; otherwise no sample is treated
    as a test sample and 0 is returned.
    """
    if args.testset>=0:
        return ocrolib.testset(i,sequence=args.testset)
    return 0

# load the dataset and find out which bucket each sample
# goes into; samples that are in the test set or in excluded
# classes are set aside by assigning them to the special bucket '-1'

with tables.openFile(args.data,"r") as h5:
    print "loading dataset"
    N = min(args.maxtotal,len(h5.root.classes))
    patches = Dataset(h5.root.patches,maxsize=N)
    print "splitting"
    splits = []
    for i,v in enumerate(patches):
        if i%10000==0: print i
        if testset(i) or re.search(args.exclude,lig.chr(h5.root.classes[i])):
            splits.append(-1)
        else:
            splits.append(splitter.predict1(v))

# give the user some feedback about cluster distributions

splits = array(splits,'i')
histogram = Counter(splits)

if args.debug:
    counts = sorted(histogram.values(),key=lambda x:-x)
    ion(); gray(); clf()
    yscale('log')
    plot(counts)
    ginput(1,1)

clusters = sorted(histogram.keys())
if len(clusters)<splitter.nclusters():
    print "warning: not all clusters present",len(clusters),splitter.nclusters()

# create a number of "jobs"; we work through these either serially or in parallel

jobs = [(cluster,find(splits==cluster)) for cluster in clusters if cluster>=0]

# the main job: training the classifier for an individual bucket

def process1(job):
    cluster,indexes = job
    if len(indexes)>args.maxtrain:
        indexes = pyrandom.sample(indexes,args.trainmax)
    note = "cluster %4d len %6d"%(cluster,len(indexes))
    with tables.openFile(args.data,"r") as h5:
        # load the classes and training data
        cclasses = [lig.chr(h5.root.classes[i]) for i in indexes]
        patches = Dataset(h5.root.patches)
        data = array([patches[i] for i in indexes],'f')

        # give the user some feedback about what the classes and samples are
        counts = Counter(cclasses).most_common(5)
        cinfo = " / ".join(["%s %s"%(k,v) for k,v in counts])
        note += "    "+cinfo
        if args.debug:
            clf();
            if len(data)>=49: showgrid(patrec.vecsort(pyrandom.sample(data,49)))
            else: showgrid(data)
            suptitle(cinfo)
            ginput(1,0.1)
        assert data.ndim==2

    # now just train the classifier and return it; the `cfactory` expression
    # should take care of any parameters
    print note
    classifier = cfactory()
    classifier.fit(data,cclasses)
    return (cluster,classifier)

# Train one classifier per job.  Debug mode always runs serially, as do
# -Q 0 and -Q 1; the result is a list of (cluster,classifier) pairs.
# NOTE(review): a negative -Q value is passed on to multiprocessing.Pool
# unchanged -- confirm whether that is intended.
if args.debug or 0<=args.parallel<2:
    classifiers = [process1(job) for job in jobs]
else:
    import multiprocessing
    pool = multiprocessing.Pool(args.parallel)
    pending = pool.map_async(process1,jobs)
    # no more work will be submitted; wait for the workers to finish
    pool.close()
    pool.join()
    del pool
    classifiers = pending.get()


# nsplitter = patrec.HierarchicalSplitter()
# nsplitter.update(splitter)

cmodel = patrec.LocalCmodel(splitter=splitter)
cmodel.normalizer = patrec.normalizer_normal

for i,c in classifiers:
    cmodel.setCmodel(i,c)

print "writing"
with open(args.output,"w") as stream:
    cPickle.dump(cmodel,stream,2)
